"""
1. git clone -b cotracker2v1_release https://github.com/facebookresearch/co-tracker.git
2. wget https://huggingface.co/facebook/cotracker/resolve/main/cotracker2.pth -P ./../
CUDA_VISIBLE_DEVICES=0 python eval_ObjMC.py
"""

# --- Evaluation configuration -------------------------------------------------
# Root folder of the datasets / generated videos.
dataset_dir = "./../dataset"
# Sub-folder (relative to the dataset root) holding the .mp4 files to evaluate.
test_id = "rebuttal/video"
# CoTracker2 weights handed to CoTrackerPredictor.
cotracker2_checkpoint = "./../cotracker2.pth"
# NOTE(review): not referenced anywhere in the visible code -- confirm before removing.
gt_id = "GT"
# VIPSeg root; the reader looks for "imgs/<id>/" and "panomasks/<id>/" below it.
VIPSeg_path = "./../dataset/VIPSeg"
# Folder with one "<id>.json" ground-truth trajectory file per clip.
json_folder = "./VIPSeg/test_traject"
# Frame height / width every video frame is resized to before tracking.
H = 320
W = 576
# Maximum number of frames read from each video.
num_frames = 14

import sys
sys.path.insert(0, "./co-tracker")
from cotracker.utils.visualizer import Visualizer
from cotracker.predictor import CoTrackerPredictor
import cv2
import numpy as np
from PIL import Image
import os 
import torch
import json
import math

#extract resized videos into numpy (np.uint8)
def read_video(videopath):
    """Read the first frames of a video as resized RGB uint8 images.

    At most ``num_frames`` frames are read; each one is resized to ``W`` x ``H``
    and converted from OpenCV's BGR channel order to RGB.

    Args:
        videopath: path of the video file to decode.

    Returns:
        np.ndarray of dtype uint8 with shape (F, H, W, 3), where
        F <= num_frames (fewer when the video is shorter).

    Raises:
        ValueError: if not a single frame could be decoded (missing file,
            unreadable video, or empty stream).
    """
    cam = cv2.VideoCapture(videopath)
    frames = []
    try:
        while len(frames) < num_frames:
            ok, frame = cam.read()
            # read() signals end-of-stream / decode failure through its flag;
            # test it explicitly instead of relying on cv2.resize(None) raising.
            if not ok or frame is None:
                break
            frame = cv2.resize(frame, (W, H))
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # always free the capture handle, even if resizing/conversion fails
        cam.release()
    if not frames:
        raise ValueError("no frames could be read from video: " + str(videopath))
    return np.stack(frames).astype(np.uint8) #F x H x W x 3

def VIPSegReader(file_id = "25_kjuAwy1voxM", trajectory_path = "./DragAnything/data/VIPSeg_Test/", num_frames=14, target_size=(576, 320)):
    """Load the ground-truth trajectories of one VIPSeg clip, rescaled to ``target_size``.

    The mask ids are taken from the first annotation frame of the clip
    (``<VIPSeg_path>/panomasks/<file_id>/``); for each of them the trajectory
    stored under that id in ``<json_folder>/<file_id>.json`` is rescaled from
    the original frame size to ``target_size`` and cut to ``num_frames`` points.

    Args:
        file_id: clip name; used as folder name below ``imgs/`` and
            ``panomasks/`` and as stem of the trajectory JSON file.
        trajectory_path: unused; kept so existing callers keep working.
        num_frames: maximum number of points kept per trajectory.
        target_size: (width, height) the coordinates are rescaled to.

    Returns:
        A list with one trajectory per mask id (ids in ascending order, as
        produced by ``np.unique``); each trajectory is a list of at most
        ``num_frames`` ``[x, y]`` integer pairs.

    Raises:
        KeyError: if a mask id of the first annotation frame has no entry in
            the trajectory JSON.
        IndexError: if the image or annotation folder is empty.
    """
    video_folder = os.path.join(VIPSeg_path, "imgs/", file_id + "/") #e.g. 00000482.png
    ann_folder = os.path.join(VIPSeg_path, "panomasks/", file_id + "/") #e.g. 00000976.png

    def sort_frames(frame_name):
        # frame files are named by their zero-padded frame number
        return int(frame_name.split('.')[0])

    #ref: https://github.com/showlab/DragAnything/blob/79355363218a7eb9b3437a31b8604b6d436d9337/utils/extract_semantic_point.py#L157
    image_files = sorted(os.listdir(video_folder), key=sort_frames)
    ann_files = sorted(os.listdir(ann_folder), key=sort_frames)

    # Only the frame size is needed, so the pixel data is never decoded and
    # the file handle is closed right away.
    with Image.open(os.path.join(video_folder, image_files[0])) as first_frame:
        original_size = first_frame.size #(width, height)
    print(original_size)

    #extract annotation file: every distinct value in the first mask is one id
    with Image.open(os.path.join(ann_folder, ann_files[0])) as ann_image:
        mask = np.array(ann_image)
    ids = np.unique(mask)

    with open(os.path.join(json_folder, file_id + ".json"), 'r') as json_file:
        trajectory_json = json.load(json_file)

    scale_w, scale_h = target_size
    orig_w, orig_h = original_size

    trajectory_list = [] #[[[x1, y1], [x2, y2], .. ], [], []]
    for mask_id in ids:
        # JSON keys are the mask ids as strings
        points = trajectory_json[str(mask_id)][:num_frames]
        trajectory_list.append(
            [[int(p[0] / orig_w * scale_w), int(p[1] / orig_h * scale_h)] for p in points]
        )

    return trajectory_list

def extract_traj(test_id):
    """Track the ground-truth start points through every video of ``test_id``.

    For each ``.mp4`` in ``<dataset_dir>/<test_id>`` the first ground-truth
    point of every object (from ``VIPSegReader``) is used as a CoTracker2 query
    on frame 0. The predicted tracks are saved to
    ``<dataset_dir>/<test_id>_Traj/<file_id>.npy`` and a visualization is
    written to ``<dataset_dir>/<test_id>_Traj_Vis``. Existing outputs are
    recomputed and overwritten.

    Args:
        test_id: sub-folder of ``dataset_dir`` containing the videos.

    Returns:
        Path of the folder the ``.npy`` trajectory files were written to.

    Note:
        The model and all tensors are moved to "cuda", so a CUDA device is
        required.
    """
    # All folders live under the configured dataset root.
    video_dir = os.path.join(dataset_dir, test_id) #video folder
    traj_dir = os.path.join(dataset_dir, test_id+"_Traj") #trajectory folder to save
    traj_vis_dir = os.path.join(dataset_dir, test_id+"_Traj_Vis") #trajectory visualizer

    os.makedirs(traj_dir, exist_ok=True)
    os.makedirs(traj_vis_dir, exist_ok=True)

    file_list = sorted(os.listdir(video_dir))
    print(len(file_list))
    # sanity check on the number of entries expected in the video folder
    assert(len(file_list)==329 or len(file_list)==330)

    #load cotracker2
    #https://github.com/showlab/DragAnything/blob/79355363218a7eb9b3437a31b8604b6d436d9337/utils/cotracker/Generate_Trajectory_for_VIPSeg.py#L104
    #model = torch.hub.load("facebookresearch/co-tracker", "cotracker2").to("cuda")
    model = CoTrackerPredictor(checkpoint=cotracker2_checkpoint).to("cuda")
    model.support_grid_size = 6 #set sparse grids (e.g. 6x6) as support points

    for i, f in enumerate(file_list):
        if not f.endswith(".mp4"):
            continue
        print("reading:", i, f)
        file_id = os.path.splitext(f)[0]
        save_path = os.path.join(traj_dir, file_id+".npy")

        #load ground truth trajectory; only its first point seeds the tracker
        trajectory_list = VIPSegReader(file_id=file_id, target_size=(W, H)) #N x F x 2
        query = [[0, obj[0][0], obj[0][1]] for obj in trajectory_list] #[t, x, y]
        query = torch.tensor(query, dtype=torch.float32).reshape((1,-1,3)).to("cuda")

        #load video
        #https://github.com/facebookresearch/co-tracker/tree/9ed05317b794cd177674e681321780614a65e073?tab=readme-ov-file#offline-mode
        video = read_video(os.path.join(video_dir, f)) #F x H x W x 3
        video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float() #1 x F x 3 x H x W
        video = video.to("cuda")

        #Run cotracker
        #https://github.com/facebookresearch/co-tracker/blob/9ed05317b794cd177674e681321780614a65e073/cotracker/predictor.py#L24
        pred_tracks, pred_visibility = model(video, queries = query) #1 x T(e.g. 14) x N  x 2,  1 x T x N x 1

        #save trajectory file, reordered to one row per object: N x F x 2
        tracks_per_object = pred_tracks[0].permute((1,0,2)).cpu().numpy()
        np.save(save_path, tracks_per_object)

        #visualize trajectory (expects the batched tensor as returned by the model)
        #https://github.com/showlab/DragAnything/blob/79355363218a7eb9b3437a31b8604b6d436d9337/utils/cotracker/cotracker/utils/visualizer.py#L78
        vis = Visualizer(save_dir=traj_vis_dir, fps = 7, linewidth = 5)
        vis.visualize(video, pred_tracks, pred_visibility, filename = file_id, query_frame=0, save_video = True)
    return traj_dir

#extract trajectories
# Runs at module load: tracks every video of `test_id` and keeps the folder
# the per-video .npy trajectory files were written to.
print("computing trajectories..")
traj_dir = extract_traj(test_id)

#evaluate ObjMc
#ref: https://github.com/showlab/DragAnything/blob/main/utils/Eval_ObjMC/ObjMC.py
def euclidean_distance(point1, point2):
    """Return the Euclidean distance between two 2-D points.

    Args:
        point1: first point as an ``(x, y)`` pair.
        point2: second point as an ``(x, y)`` pair.

    Returns:
        The distance as a Python float.
    """
    x1, y1 = point1
    x2, y2 = point2
    # hypot avoids the overflow/underflow of squaring the differences by hand
    # and always evaluates in double precision, whatever the input dtype.
    return math.hypot(x2 - x1, y2 - y1)

# NOTE(review): gt_json is not read below; the ground truth comes from VIPSegReader.
gt_json = "./data/VIPSeg_Test/trajectories"
prediction_json = traj_dir

# sorted so the progress log is deterministic across runs and file systems
pred_list = sorted(os.listdir(prediction_json))
# sanity check on the number of prediction files expected in the folder
assert(len(pred_list)==329)

ED_list = [] # distance of every evaluated (object, frame) pair
total = 0    # all (object, frame) pairs seen
outside = 0  # pairs skipped because the ground-truth point is off-image
for i, npy_file in enumerate(pred_list):
    if not npy_file.endswith(".npy"):
        continue
    file_id = npy_file[:-4]
    # ground truth is cut to the same frame count the predictions were made with
    gt_pred = np.array(VIPSegReader(file_id=file_id, num_frames=num_frames, target_size=(W, H))).astype(np.float32) #[N x F x 2]
    trajectory_pred = np.load(os.path.join(traj_dir, npy_file))
    print("calcualte:", i, file_id)
    assert(trajectory_pred.shape[1]==num_frames)
    for j in range(trajectory_pred.shape[0]):
        for k in range(trajectory_pred.shape[1]):
            point1, point2 = gt_pred[j][k], trajectory_pred[j][k]
            total = total + 1
            if 0 <= point1[0] and 0 <= point1[1] and point1[0] < W and point1[1] < H:
                ED = euclidean_distance(point1,point2)
                ED_list.append(ED)
            else:
                #Note: exclude points beyond image space
                outside = outside + 1
assert(len(ED_list)==total-outside)
print("mean euclidean distance", np.mean(ED_list))
print("outside", outside)
print("total", total)